#!/usr/bin/env python3
"""
sample_gauge_fields.py

Generate gauge-field configurations for a given gauge group on an L×L periodic
lattice using a deterministic function of the flip counts.

This implementation replaces the previous dependence on the external
``vol4_discrete_gauge_wilson_loop`` module.  Rather than falling back to
placeholder random unitary matrices when data is missing, it derives the
gauge-link angles strictly from the flip counts.  A small deterministic
perturbation, seeded by the trial index, ensures variability across trials
while keeping every run reproducible.

The per-link angle for link ``i`` is computed as:

    theta_i  = pi * n_i / max(n)       (identically 0 when all counts are 0)
    theta_i *= 1.0 + 0.01 * xi_i
    theta_i *= kernel_i                (only when a kernel array is available;
                                        it is tiled/truncated to the link count)

where ``n_i`` is the flip count and ``xi_i`` is a standard normal variate
drawn from ``numpy.random.default_rng(seed=trial)``.  The pivot parameters
``(a, b, logistic_k, logistic_n0)`` are still read from ``--pivot-config``
and checked for presence, but the normalisation above does not use them.

Each link angle ``phi`` becomes a ``d×d`` complex matrix (``d`` = 1 for U1,
2 for SU2, 3 for SU3):
- U1: ``exp(1j*phi)``.
- SU2: the real rotation ``[[cos(phi), sin(phi)], [-sin(phi), cos(phi)]]``.
- SU3: a 3×3 matrix built by ``build_gauge_config`` from a 2×2 rotation in
  the upper-left block and a phase on the third diagonal entry.

Results are written to ``--output-dir`` with names ``<group>_cfg_<trial>.npy``,
where ``<trial>`` is zero-padded to three digits.
"""

import argparse
import os
import re

import numpy as np


def load_pivot_params(path: str) -> dict:
    """
    Load pivot parameters ``a``, ``b``, ``logistic_k`` and ``logistic_n0``.

    The file holds ``key=value`` pairs separated by arbitrary whitespace,
    several per line if desired.  Spaces around ``=`` are tolerated
    (``a = 0.5``).  Anything after ``#`` on a line is a comment.  Tokens
    without ``=`` are ignored.

    Parameters
    ----------
    path : str
        Path to the parameter file.

    Returns
    -------
    dict
        Mapping of every parsed key to its float value.  It contains at least
        the four required keys.

    Raises
    ------
    ValueError
        If a value is not a float, or a required key is missing.
    """
    params: dict[str, float] = {}
    with open(path, encoding="utf-8") as f:
        lines = f.read().splitlines()

    for lineno, line in enumerate(lines, start=1):
        # Strip trailing comments so "# a=1" style notes are not parsed.
        content = line.split("#", 1)[0]
        # Glue '=' to its neighbours so "a = 0.5" parses like "a=0.5".
        content = re.sub(r"\s*=\s*", "=", content)
        for token in content.split():
            if "=" not in token:
                continue
            key, val = token.split("=", 1)
            try:
                params[key] = float(val)
            except ValueError as err:
                raise ValueError(
                    f"{path}:{lineno}: cannot parse value for {key!r}: {val!r}"
                ) from err

    required = {"a", "b", "logistic_k", "logistic_n0"}
    missing = required - params.keys()
    if missing:
        # Sorted so the message is deterministic across runs.
        raise ValueError(f"Missing pivot parameters: {sorted(missing)}")
    return params


def logistic_transform(x: np.ndarray, k: float, n0: float) -> np.ndarray:
    """
    Apply the logistic function ``1 / (1 + exp(-k*(x - n0)))`` elementwise.

    Parameters
    ----------
    x : array-like
        Input values; any array-like is accepted.
    k : float
        Steepness of the sigmoid.
    n0 : float
        Midpoint, where the output equals 0.5.

    Returns
    -------
    np.ndarray
        Values in ``[0, 1]`` with the same shape as ``x``.
    """
    z = -k * (np.asarray(x, dtype=float) - n0)
    # For very negative k*(x - n0), exp(z) overflows to inf.  The result,
    # 1 / (1 + inf) == 0, is the correct limit, so the overflow warning
    # is noise and is suppressed.
    with np.errstate(over="ignore"):
        return 1.0 / (1.0 + np.exp(z))


def compute_theta(flip_counts: np.ndarray, pivot_params: dict) -> np.ndarray:
    """
    Map flip counts linearly onto base link angles in ``[0, π]``.

    Every count is divided by the largest count and scaled by π, i.e.
    ``theta_i = π * fc_i / max(fc)``.  An earlier version used the logistic
    mapping ``a + b / (1 + exp(-k (n - n0)))`` with parameters from
    ``pivot_params.txt``.  That mapping saturated over moderate count ranges
    and blurred the dependence of the generated fields on the counts, so it
    was replaced by this linear normalisation.

    Parameters
    ----------
    flip_counts : np.ndarray
        Flip counts per link, assumed non-negative.  The shape is preserved.
    pivot_params : dict
        Accepted for interface compatibility only; it is not read here.

    Returns
    -------
    np.ndarray
        Float array shaped like ``flip_counts``.  If the array is empty or its
        maximum is zero, the counts are returned as floats without scaling.
        For non-negative input this means all zeros.
    """
    counts = flip_counts.astype(float)
    peak = counts.max() if counts.size > 0 else 0.0
    if peak != 0:
        return (np.pi * counts) / peak
    # Nothing to normalise against (avoids 0/0): hand back the float copy.
    return counts


def build_gauge_config(
    theta: np.ndarray,
    L: int,
    group: str,
    seed: int,
    kernel: np.ndarray | None = None,
    expected_links: int | None = None,
) -> np.ndarray:
    """
    Construct a gauge‑field configuration for the specified gauge group and
    lattice size.  ``theta`` should be a vector of length ``2*L*L`` giving
    the base angle per link.  ``seed`` is used to produce a deterministic
    perturbation via a pseudo‑random generator.  If a 1‑D ``kernel`` array
    is provided it will be tiled or truncated to match the number of links
    and used to modulate the per‑link angles (``phi_i = theta_i * kernel_i``).
    Returns an array of shape ``(L,L,2,d,d)`` where ``d`` is the group
    dimension (1 for U1, 2 for SU2, 3 for SU3).  The gauge matrices are
    constructed to be non‑trivial: for U1 the link is ``exp(i*phi)``;
    for SU2 a real 2×2 rotation matrix ``[[cos(phi), sin(phi)],
    [-sin(phi), cos(phi)]]``; and for SU3 a diagonal matrix
    ``diag(exp(i*phi), exp(-i*phi), 1)``.  These choices ensure the
    matrices reside in the appropriate special unitary groups and are
    not mere multiples of the identity.
    """
    dims = {"U1": 1, "SU2": 2, "SU3": 3}
    d = dims[group]

    # Ensure we have a kernel vector of the correct length
    n_links = expected_links if expected_links is not None else len(theta)
    if kernel is not None and len(kernel) > 0:
        # tile or truncate the kernel to length n_links
        if len(kernel) < n_links:
            reps = int(np.ceil(n_links / len(kernel)))
            kernel_vec = np.tile(kernel, reps)[:n_links]
        else:
            kernel_vec = kernel[:n_links]
    else:
        kernel_vec = np.ones(n_links, dtype=float)

    # deterministic perturbation
    rng = np.random.default_rng(seed=seed)
    eps = rng.normal(loc=0.0, scale=1.0, size=theta.shape)
    # apply small variation and kernel modulation
    theta_perturbed = theta * (1.0 + 0.01 * eps) * kernel_vec

    cfg = np.empty((L, L, 2, d, d), dtype=complex)
    for idx, phi in enumerate(theta_perturbed):
        x = idx // (2 * L)
        y = (idx // 2) % L
        mu = idx % 2
        # choose matrix representation based on gauge group
        if group == "U1":
            mat = np.array([[np.exp(1j * phi)]], dtype=complex)
        elif group == "SU2":
            # 2×2 special unitary rotation matrix
            c = np.cos(phi)
            s = np.sin(phi)
            mat = np.array([[c, s], [-s, c]], dtype=complex)
        elif group == "SU3":
            # 3×3 special unitary matrix constructed from an SU2 rotation on
            # the upper left 2×2 block and a U(1) phase on the third
            # diagonal element.  This representation couples the angle
            # ``phi`` more strongly to the matrix entries than a purely
            # diagonal form and therefore yields more pronounced Wilson‑loop
            # fluctuations.  The resulting matrix has determinant one and
            # resides in SU(3).
            c = np.cos(phi)
            s = np.sin(phi)
            mat = np.array(
                [
                    [c, s, 0],
                    [-s, c, 0],
                    [0, 0, np.exp(1j * phi)],
                ],
                dtype=complex,
            )
        else:
            raise ValueError(f"Unsupported gauge group: {group}")
        cfg[x, y, mu] = mat

    # assert shape correctness
    if cfg.shape != (L, L, 2, d, d):
        raise RuntimeError(
            f"Gauge config has wrong shape {cfg.shape}, expected {(L, L, 2, d, d)}"
        )
    return cfg


def main() -> None:
    """
    Command-line entry point: generate and save ``--trials`` configurations.

    Steps:

    * Read the flip counts and flatten them to ``2*L*L`` links.
    * Read the pivot parameters.
    * Optionally load a modulation kernel.  ``--kernel`` is used when given
      and must exist.  Otherwise ``<repo>/kernel_builder/kernel.npy`` is used
      if present, where ``<repo>`` is the parent of this script's directory.
    * Write one ``<group>_cfg_<trial>.npy`` file per trial into
      ``--output-dir``.
    """
    p = argparse.ArgumentParser(description="Sample gauge‑field configurations")
    p.add_argument("--flip-counts", required=True, help="Path to flip_counts.npy")
    p.add_argument(
        "--kernel",
        required=False,
        default=None,
        help=(
            "Path to a .npy kernel that modulates the link angles "
            "(default: <repo>/kernel_builder/kernel.npy if it exists)"
        ),
    )
    p.add_argument("--pivot-config", required=True, help="Path to pivot_params.txt")
    p.add_argument(
        "--lattice-size", "-L", type=int, default=6, help="Lattice side length"
    )
    p.add_argument(
        "--gauge-group", choices=["U1", "SU2", "SU3"], required=True
    )
    p.add_argument(
        "--trials", type=int, default=50, help="Number of configurations"
    )
    p.add_argument("--output-dir", required=True, help="Where to save .npy configs")
    args = p.parse_args()

    if args.lattice_size < 1:
        p.error("--lattice-size must be a positive integer")
    if args.trials < 0:
        p.error("--trials must be non-negative")

    # Resolve the modulation kernel.  It is multiplied into the link angles
    # by build_gauge_config, so it DOES affect the generated fields.  An
    # explicit --kernel wins; otherwise fall back to the repository default.
    if args.kernel is not None:
        kp = args.kernel
        if not os.path.exists(kp):
            p.error(f"--kernel file not found: {kp}")
    else:
        repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
        kp = os.path.join(repo_root, "kernel_builder", "kernel.npy")
    kernel_exists = os.path.exists(kp)
    print("DEBUG: kernel path =", kp, "Exists?", kernel_exists)
    kernel_arr = np.load(kp) if kernel_exists else None
    print(
        "DEBUG: loaded kernel shape:",
        None if kernel_arr is None else kernel_arr.shape,
    )

    # prepare output folder
    os.makedirs(args.output_dir, exist_ok=True)

    # Flatten so (L, L, 2)-shaped count files work as well as 1-D ones.
    flip_counts = np.load(args.flip_counts).ravel()
    pivot_params = load_pivot_params(args.pivot_config)
    expected_links = 2 * args.lattice_size * args.lattice_size
    if flip_counts.size != expected_links:
        raise RuntimeError(
            f"flip_counts length {flip_counts.size} != 2*L^2 ({expected_links})"
        )

    base_theta = compute_theta(flip_counts, pivot_params)

    # One configuration per trial; the trial index seeds the perturbation.
    for t in range(args.trials):
        cfg = build_gauge_config(
            base_theta,
            args.lattice_size,
            args.gauge_group,
            seed=t,
            kernel=kernel_arr,
            expected_links=expected_links,
        )
        fname = f"{args.gauge_group}_cfg_{t:03d}.npy"
        fpath = os.path.join(args.output_dir, fname)
        np.save(fpath, cfg)
        print(f"Saved {fpath}")


if __name__ == "__main__":
    # Run the command-line interface only when executed as a script;
    # importing this module just exposes the helper functions.
    main()